/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "shape.h"
#include <sstream>

namespace ge {
// Returns the value of dimension `index`.
// `index` is not validated against GetDimNum(); callers must keep it in range.
int64_t Shape::GetDim(size_t index) const {
  const int64_t dim_value = dims_[index];
  return dim_value;
}
// Returns a mutable reference to dimension `index`; writes through it modify
// dims_ directly. `index` is not validated against GetDimNum().
int64_t &Shape::GetDim(size_t index) {
  int64_t &slot = dims_[index];
  return slot;
}
// Returns the logical rank of the shape, i.e. the value tracked in rank_.
size_t Shape::GetDimNum() const {
  const size_t dim_num = rank_;
  return dim_num;
}
// Overwrites dimension `index` with `dim_value`.
// Neither `index` nor rank_ is checked or updated here, so this always
// reports SUCCESS.
Status Shape::SetDim(size_t index, int64_t dim_value) {
  auto &target = dims_[index];
  target = dim_value;
  return SUCCESS;
}
bool Shape::IsScaler() const {
  return dims_.empty();
}
int64_t Shape::GetShapeSize() const {
  return GetElementCount();
}
int64_t Shape::GetElementCount() const {
  if (dims_.empty()) {
    return 1;  // TODO 原有数据结构这里返回0，但是我认为做的不对，为空代表标量，那么element个数是一个
  }
  int64_t shape_size = 1;
  for (auto dim : dims_) {
    if (dim == kUnknownDim || dim == kUnknownDimNum) {
      return kUnknownDim;
    }
    shape_size *= dim;  // TODO 这里可能反转，最好考虑一下，不过老代码没考虑，这里也没写
  }
  return shape_size;
}
// Replaces the shape's dimensions with `dims`, in order, and sets rank_ to
// dims.size(). Always returns SUCCESS.
// NOTE(review): dims.size() is not checked against the capacity of dims_; if
// dims_ is fixed-size storage, an oversized list would write past its end —
// confirm callers bound the rank.
Status Shape::SetDims(const std::initializer_list<int64_t> &dims) {
  rank_ = 0;
  for (auto it = dims.begin(); it != dims.end(); ++it) {
    dims_[rank_] = *it;
    ++rank_;
  }
  return SUCCESS;
}
// Builds a shape from a braced list of dimensions, e.g. Shape({2, 3}).
// Delegates to SetDims, which always returns SUCCESS, so the status is
// intentionally discarded.
Shape::Shape(std::initializer_list<int64_t> dims) {
  (void)SetDims(dims);
}
// Renders the first rank_ dimensions as a comma-separated list with no
// spaces or brackets, e.g. "2,3,4". A rank-0 shape yields an empty string.
std::string Shape::ToString() const {
  std::stringstream ss;
  const char *separator = "";
  for (size_t pos = 0; pos < rank_; ++pos) {
    ss << separator << dims_[pos];
    separator = ",";
  }
  return ss.str();
}
// Builds a shape from a vector of dimensions: copies dims into the first
// dims.size() slots of dims_ and sets rank_ to dims.size().
// NOTE(review): dims.size() is not checked against the capacity of dims_; if
// dims_ is fixed-size storage, an oversized vector would write past its end —
// confirm callers bound the rank.
Shape::Shape(const std::vector<int64_t> &dims) {
  for (size_t i = 0; i < dims.size(); ++i) {
    dims_[i] = dims[i];
  }
  rank_ = dims.size();
}
// Returns a read-only reference to the raw dimension storage.
// The container may hold more slots than GetDimNum(); in this file only the
// first GetDimNum() entries are written by SetDims and the constructors.
const DimsType &Shape::GetDims() const {
  const DimsType &storage = dims_;
  return storage;
}

std::vector<int64_t> Shape::GetDimsVec() const {
  return std::vector<int64_t>({dims_.begin(), dims_.end()});
}
void Shape::SetDimNum(size_t dim_num) {
  rank_ = dim_num;
}
// Two shapes are equal when they have the same rank and the same value in
// each of the first rank_ dimensions; storage beyond rank_ is ignored.
bool Shape::operator==(const Shape &other) const {
  if (rank_ == other.rank_) {
    size_t matched = 0;
    while (matched < rank_ && dims_[matched] == other.dims_[matched]) {
      ++matched;
    }
    return matched == rank_;
  }
  return false;
}
// Logical negation of operator==.
bool Shape::operator!=(const Shape &other) const {
  const bool same = operator==(other);
  return !same;
}
void Range::SetDimNum(size_t dim_num) {
  rank_ = dim_num;
}
// Returns a read-only reference to the (first, second) range pair stored for
// dimension `index`. `index` is not validated against GetDimsCount().
const std::pair<int64_t, int64_t> &Range::GetDimRange(size_t index) const {
  const auto &dim_range = range_[index];
  return dim_range;
}
// Stores `dim_range` as the range pair for dimension `index`.
// Neither `index` nor rank_ is checked or updated here.
void Range::SetDimRange(size_t index, std::pair<int64_t, int64_t> dim_range) {
  auto &slot = range_[index];
  slot = dim_range;
}
// Returns a mutable reference to the range pair for dimension `index`; writes
// through it modify range_ directly. `index` is not validated.
std::pair<int64_t, int64_t> &Range::GetDimRange(size_t index) {
  auto &dim_range = range_[index];
  return dim_range;
}
// Returns the number of dimension ranges in use (the value tracked in rank_).
size_t Range::GetDimsCount() const {
  const size_t count = rank_;
  return count;
}
}